'''
PyTorch Dataset for the SJTU OpenSARShip ship-classification image patches.

Author: SlytherinGe
LastEditTime: 2021-02-24 13:35:35
'''
# torch
import torch
import torch.utils.data as data
# utility
import numpy as np
import cv2
# system
import os
import os.path as osp
import glob as glob
import random

# Ship categories of the dataset, in label-index order: the position of a name
# in this list is the index set to 1 in the one-hot label vector.
CLASSES = ['Cargo', 'Tanker', 'Other Type', 'Tug', 'Fishing', 'Search', 'Passenger',
        'Dredging', 'Pilot Vessel', 'Port Tender', 'Wing in ground', 'High speed craft',
        'Law Enforcement', 'Towing', 'Diving ops', 'Pleasure Craft']

# Class name -> label index. Derived from CLASSES so the two can never drift
# apart when a category is added, removed or reordered.
map_name_to_index = {name: index for index, name in enumerate(CLASSES)}

class STJUSarShipDataset(data.Dataset):
    """Classification dataset over OpenSARShip ``Patch_Uint8`` image patches.

    Directory layout read by ``__init__``::

        data_root/
            <folder>/
                Patch_Uint8/
                    <field0>_<ClassName>[_...]   # class = 2nd '_'-separated field

    Each item is ``(image, label)``:

    * ``image``: the image loaded with ``cv2.imread(..., IMREAD_COLOR)``,
      passed through ``transform`` (if any), converted to a tensor and
      permuted from HxWxC to CxHxW.
    * ``label``: one-hot ``torch.long`` vector of length ``len(CLASSES)``.

    The class name is spelled ``STJU`` (the dataset is from SJTU). It is kept
    as-is for backward compatibility.
    """

    def __init__(self, data_root, transform):
        """Index every patch under ``<data_root>/<folder>/Patch_Uint8``.

        Args:
            data_root: directory whose sub-folders each contain a
                ``Patch_Uint8`` folder of image patches.
            transform: optional callable applied to the loaded image in
                ``__getitem__``. It must return an HxWxC numpy array, because
                the result is fed to ``torch.from_numpy`` and permuted
                ``(2, 0, 1)``. ``None`` disables it.
        """
        self.img_and_label = []     # (image path, class name) pairs
        self.data_root = data_root
        self.transform = transform
        self.num_classes = len(CLASSES)
        print("reading data info into memory")
        # sorted(): os.listdir order is arbitrary, so sort to keep item
        # indices reproducible across runs and platforms.
        for folder in sorted(os.listdir(self.data_root)):
            img_folder = osp.join(self.data_root, folder, 'Patch_Uint8')
            # Skip stray files / folders without a patch directory instead of
            # letting os.listdir raise on them.
            if not osp.isdir(img_folder):
                continue
            for img_name in sorted(os.listdir(img_folder)):
                fields = img_name.split('_')
                # Names with no '_'-separated class field (e.g. Thumbs.db)
                # are not patches; skip them instead of raising IndexError.
                if len(fields) < 2:
                    continue
                self.img_and_label.append((osp.join(img_folder, img_name), fields[1]))
        print('done reading infos! read {} infos'.format(len(self)))

    def __len__(self):
        """Return the number of indexed patches."""
        return len(self.img_and_label)

    def __getitem__(self, index):
        """Return ``(image, one_hot_label)`` for item ``index``.

        Raises:
            FileNotFoundError: if the image is missing or OpenCV cannot read it.
            KeyError: if the class name parsed from the file name is not a key
                of ``map_name_to_index``.
        """
        img_path, class_name = self.img_and_label[index]
        # cv2.imread signals failure by returning None rather than raising.
        # This explicit check also survives `python -O`, unlike an assert.
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(
                'Image path does not exist or cannot be read: {}'.format(img_path))

        if self.transform is not None:
            img = self.transform(img)

        label = torch.zeros(self.num_classes, dtype=torch.long)
        label[map_name_to_index[class_name]] = 1

        # torch.from_numpy rejects negative strides, which a transform such as
        # img[:, :, ::-1] (BGR -> RGB) produces. ascontiguousarray fixes that
        # and is a no-op for already-contiguous arrays.
        return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1), label

    def __repr__(self):
        """Multi-line summary: size, root directory and transform."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(len(self))
        fmt_str += '    Root Location: {}\n'.format(self.data_root)
        tmp = '    Transforms (if any): '
        # Indent continuation lines of a multi-line transform repr under the label.
        fmt_str += '{0}{1}\n'.format(tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


if __name__ == '__main__':
    # Smoke test: index the dataset and load a single sample.
    # The dataset root may be given as the first command-line argument. The
    # hard-coded path is only the original author's local default.
    import sys

    ROOT = sys.argv[1] if len(sys.argv) > 1 else \
        r'D:\develop\Deep Learning\DataSet\SJTU\OpenSARShip Total\OpenSARShip_total'

    dataset = STJUSarShipDataset(ROOT, None)

    if len(dataset) == 0:
        print('no images found under {}'.format(ROOT))
    else:
        # Item 10 as before, clamped so small datasets do not raise IndexError.
        it = dataset[min(10, len(dataset) - 1)]
        print(it)

    


